datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.SeekableReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.conversion.TypeConversion;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import static java.nio.channels.Channels.newChannel;
@Slf4j
public class ArrowConverter {
/**
 * Create an ndarray from a time series batch.
 * The conversion is delegated entirely to
 * {@code RecordConverter.toTensor}.
 *
 * @param arrowWritableRecordBatch the incoming time series batch. This is typically output from
 * an {@link org.datavec.arrow.recordreader.ArrowRecordReader}
 * @return the {@link INDArray} produced by the record converter for the input data
 */
public static INDArray toArray(ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch) {
    final INDArray tensor = RecordConverter.toTensor(arrowWritableRecordBatch);
    return tensor;
}
/**
 * Create an ndarray from a matrix.
 * The included batch must be all the same number of rows in order
 * to work. The reason for this is {@link INDArray} must be all the same dimensions.
 * Note that the input columns must also be numerical. If they aren't numerical already,
 * consider using an {@link org.datavec.api.transform.TransformProcess} to transform the data
 * output from {@link org.datavec.arrow.recordreader.ArrowRecordReader} in to the proper format
 * for usage with this method for direct conversion.
 * <p>
 * A batch consisting of a single NDArray column is handled separately: each row is
 * deserialized with {@link BinarySerde} and the results are concatenated along dimension 0.
 *
 * @param arrowWritableRecordBatch the incoming batch. This is typically output from
 * an {@link org.datavec.arrow.recordreader.ArrowRecordReader}
 * @return an {@link INDArray} representative of the input data
 * @throws ND4JIllegalArgumentException if a column is not of type Integer, Float, Double,
 * Long or NDArray, or if the array created for the result is neither FLOAT nor DOUBLE
 */
public static INDArray toArray(ArrowWritableRecordBatch arrowWritableRecordBatch) {
    List<FieldVector> columnVectors = arrowWritableRecordBatch.getList();
    Schema schema = arrowWritableRecordBatch.getSchema();
    // reject any column type that has no numerical representation before doing any work
    for(int i = 0; i < schema.numColumns(); i++) {
        switch(schema.getType(i)) {
            case Integer:
            case Float:
            case Double:
            case Long:
            case NDArray:
                break;
            default:
                throw new ND4JIllegalArgumentException("Illegal data type found for column " + schema.getName(i) + " of type " + schema.getType(i));
        }
    }

    int rows = columnVectors.get(0).getValueCount();
    if(schema.numColumns() == 1 && schema.getMetaData(0).getColumnType() == ColumnType.NDArray) {
        INDArray[] toConcat = new INDArray[rows];
        VarBinaryVector valueVectors = (VarBinaryVector) columnVectors.get(0);
        for(int i = 0; i < rows; i++) {
            byte[] bytes = valueVectors.get(i);
            ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length);
            direct.put(bytes);
            INDArray fromTensor = BinarySerde.toArray(direct);
            toConcat[i] = fromTensor;
        }
        return Nd4j.concat(0,toConcat);
    }

    int cols = schema.numColumns();
    INDArray arr = Nd4j.create(rows,cols);
    for(int i = 0; i < cols; i++) {
        INDArray put = ArrowConverter.convertArrowVector(columnVectors.get(i),schema.getType(i));
        switch(arr.data().dataType()) {
            case FLOAT:
                arr.putColumn(i,Nd4j.create(put.data().asFloat()).reshape(rows,1));
                break;
            case DOUBLE:
                arr.putColumn(i,Nd4j.create(put.data().asDouble()).reshape(rows,1));
                break;
            default:
                // without this branch the column would silently stay zero-filled
                throw new ND4JIllegalArgumentException("Unsupported data type " + arr.data().dataType()
                        + " for the target array; only FLOAT and DOUBLE are supported");
        }
    }
    return arr;
}
/**
 * Convert a field vector to a column vector.
 * The contents of the vector's data buffer are copied into a direct,
 * native-order {@link ByteBuffer} which is then handed to
 * {@code Nd4j.createBuffer} with the data type matching the column type.
 *
 * @param fieldVector the field vector to convert
 * @param type the type of the column vector; only Integer, Float, Double and Long are supported
 * @return the converted ndarray, created with shape {@code [valueCount, 1]}
 * @throws IllegalArgumentException if the column type is not one of the supported types
 */
public static INDArray convertArrowVector(FieldVector fieldVector,ColumnType type) {
    int length = fieldVector.getValueCount();
    ByteBuffer direct = ByteBuffer.allocateDirect((int) fieldVector.getDataBuffer().capacity());
    direct.order(ByteOrder.nativeOrder());
    fieldVector.getDataBuffer().getBytes(0,direct);
    // reset the position to 0 before the buffer is read (call made through the Buffer base type)
    Buffer asBuffer = (Buffer) direct;
    asBuffer.rewind();

    final DataBuffer buffer;
    switch(type) {
        case Integer:
            buffer = Nd4j.createBuffer(direct, DataType.INT,length,0);
            break;
        case Float:
            buffer = Nd4j.createBuffer(direct, DataType.FLOAT,length);
            break;
        case Double:
            buffer = Nd4j.createBuffer(direct, DataType.DOUBLE,length);
            break;
        case Long:
            buffer = Nd4j.createBuffer(direct, DataType.LONG,length);
            break;
        default:
            // previously a null buffer was passed on to Nd4j.create for these types
            throw new IllegalArgumentException("Illegal type " + type);
    }
    return Nd4j.create(buffer,new int[] {length,1});
}
/**
 * Convert an {@link INDArray}
 * to a list of column vectors or a singleton
 * list when either a row vector or a column vector.
 * For a vector input only the first entry of {@code name} is used;
 * otherwise one field vector is created per column, named by the
 * entry of {@code name} at the column's index.
 *
 * @param from the input array
 * @param name the names of the vectors
 * @param type the type of the vector (Double, Float or Integer)
 * @param bufferAllocator the allocator to use
 * @return the list of field vectors
 * @throws IllegalArgumentException if the type is not Double, Float or Integer
 */
public static List<FieldVector> convertToArrowVector(INDArray from,List<String> name,ColumnType type,BufferAllocator bufferAllocator) {
    List<FieldVector> ret = new ArrayList<>();
    if(from.isVector()) {
        ret.add(arrowVectorForColumn(from,name.get(0),type,bufferAllocator));
    }
    else {
        long cols = from.size(1);
        for(int i = 0; i < cols; i++) {
            ret.add(arrowVectorForColumn(from.getColumn(i),name.get(i),type,bufferAllocator));
        }
    }
    return ret;
}

/**
 * Build a single arrow vector from the values of one ndarray vector/column.
 */
private static FieldVector arrowVectorForColumn(INDArray values,String name,ColumnType type,BufferAllocator bufferAllocator) {
    // a view is duplicated first, presumably so the buffer holds only this array's values;
    // the non-view case must read the array's own buffer (not the enclosing matrix)
    DataBuffer data = values.isView() ? values.dup().data() : values.data();
    switch(type) {
        case Double:
            return vectorFor(bufferAllocator,name,data.asDouble());
        case Float:
            return vectorFor(bufferAllocator,name,data.asFloat());
        case Integer:
            return vectorFor(bufferAllocator,name,data.asInt());
        default:
            throw new IllegalArgumentException("Illegal type " + type);
    }
}
/**
 * Write the records to the given output stream.
 * A new {@link RootAllocator} with a limit of {@link Long#MAX_VALUE} is created
 * for the call and passed to the allocator-accepting overload; it is not closed here.
 *
 * @param recordBatch the record batch to write
 * @param inputSchema the input schema
 * @param outputStream the output stream to write to
 */
public static void writeRecordBatchTo(List<List<Writable>> recordBatch, Schema inputSchema,OutputStream outputStream) {
    writeRecordBatchTo(new RootAllocator(Long.MAX_VALUE),recordBatch,inputSchema,outputStream);
}
/**
 * Write the records to the given output stream in the arrow file format.
 * The records are converted to arrow columns, wrapped in a
 * {@link VectorSchemaRoot} and written as a single batch.
 *
 * @param bufferAllocator the allocator used to create the arrow columns
 * @param recordBatch the record batch to write
 * @param inputSchema the input schema
 * @param outputStream the output stream to write to
 * @throws IllegalStateException wrapping any {@link IOException} raised while writing
 */
public static void writeRecordBatchTo(BufferAllocator bufferAllocator ,List<List<Writable>> recordBatch, Schema inputSchema,OutputStream outputStream) {
    final org.apache.arrow.vector.types.pojo.Schema arrowSchema = toArrowSchema(inputSchema);
    final List<FieldVector> arrowColumns = toArrowColumns(bufferAllocator,inputSchema,recordBatch);

    // both the root and the writer are closed when the block exits
    try (VectorSchemaRoot schemaRoot = new VectorSchemaRoot(arrowSchema,arrowColumns,recordBatch.size());
         ArrowFileWriter fileWriter = new ArrowFileWriter(schemaRoot, providerForVectors(arrowColumns,arrowSchema.getFields()), newChannel(outputStream))) {
        fileWriter.start();
        fileWriter.writeBatch();
        fileWriter.end();
    } catch (IOException ioFailure) {
        throw new IllegalStateException(ioFailure);
    }
}
/**
 * Wrap the input field vectors (the input data) and
 * the given schema in an {@link ArrowWritableRecordTimeSeriesBatch}.
 *
 * @param fieldVectors the field vectors to use
 * @param schema the schema to use
 * @param timeSeriesLength the length of the time series
 * @return the equivalent datavec time series batch given the input data
 */
public static List<List<List<Writable>>> toArrowWritablesTimeSeries(List<FieldVector> fieldVectors,Schema schema,int timeSeriesLength) {
    return new ArrowWritableRecordTimeSeriesBatch(fieldVectors,schema,timeSeriesLength);
}
/**
 * Wrap the input field vectors (the input data) and
 * the given schema in an {@link ArrowWritableRecordBatch}.
 *
 * @param fieldVectors the field vectors to use
 * @param schema the schema to use
 * @return the equivalent datavec batch given the input data
 */
public static ArrowWritableRecordBatch toArrowWritables(List<FieldVector> fieldVectors,Schema schema) {
    return new ArrowWritableRecordBatch(fieldVectors,schema);
}
/**
 * Return a singular record based on the converted
 * writables result: the first record (index 0) of the batch
 * built by {@link #toArrowWritables(List, Schema)}.
 *
 * @param fieldVectors the field vectors to use
 * @param schema the schema to use for input
 * @return the first record of the converted batch
 */
public static List<Writable> toArrowWritablesSingle(List<FieldVector> fieldVectors,Schema schema) {
    final ArrowWritableRecordBatch batch = toArrowWritables(fieldVectors,schema);
    return batch.get(0);
}
/**
 * Read a datavec schema and record set
 * from the given arrow file.
 * Only the first batch of the file is loaded. Neither the reader nor the
 * allocator created here is closed by this method.
 *
 * @param input the input to read
 * @return the associated datavec schema and record
 * @throws IOException if reading from the input fails
 */
public static Pair<Schema,ArrowWritableRecordBatch> readFromFile(FileInputStream input) throws IOException {
    final BufferAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);
    final ArrowFileReader fileReader = new ArrowFileReader(new SeekableReadChannel(input.getChannel()), rootAllocator);
    fileReader.loadNextBatch();

    final VectorSchemaRoot schemaRoot = fileReader.getVectorSchemaRoot();
    final Schema datavecSchema = toDatavecSchema(schemaRoot.getSchema());

    // unload the batch from the root and load it straight back in
    final VectorUnloader batchUnloader = new VectorUnloader(schemaRoot);
    final VectorLoader batchLoader = new VectorLoader(schemaRoot);
    final ArrowRecordBatch arrowBatch = batchUnloader.getRecordBatch();
    batchLoader.load(arrowBatch);

    final ArrowWritableRecordBatch records = asDataVecBatch(arrowBatch,datavecSchema,schemaRoot);
    records.setUnloader(batchUnloader);
    return Pair.of(datavecSchema,records);
}
/**
 * Read a datavec schema and record set
 * from the given arrow file.
 * The file is opened for the duration of the call and closed again before
 * returning.
 *
 * @param input the input to read
 * @return the associated datavec schema and record
 * @throws IOException if the file cannot be opened or read
 */
public static Pair<Schema,ArrowWritableRecordBatch> readFromFile(File input) throws IOException {
    // The stream overload loads the batch via loadNextBatch() before returning, so the
    // file handle is released here instead of being leaked.
    // NOTE(review): assumes the returned batch never reads from the file lazily - confirm.
    try (FileInputStream inputStream = new FileInputStream(input)) {
        return readFromFile(inputStream);
    }
}
/**
 * Read a datavec schema and record set
 * from the given bytes (usually expected to be an arrow format file).
 * Only the first batch is loaded. Neither the reader nor the
 * allocator created here is closed by this method.
 *
 * @param input the input to read
 * @return the associated datavec schema and record
 * @throws IOException if reading from the bytes fails
 */
public static Pair<Schema,ArrowWritableRecordBatch> readFromBytes(byte[] input) throws IOException {
    final BufferAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);
    final ArrowFileReader bytesReader = new ArrowFileReader(
            new SeekableReadChannel(new ByteArrayReadableSeekableByteChannel(input)), rootAllocator);
    bytesReader.loadNextBatch();

    final VectorSchemaRoot schemaRoot = bytesReader.getVectorSchemaRoot();
    final Schema datavecSchema = toDatavecSchema(schemaRoot.getSchema());

    // unload the batch from the root and load it straight back in
    final VectorUnloader batchUnloader = new VectorUnloader(schemaRoot);
    final VectorLoader batchLoader = new VectorLoader(schemaRoot);
    final ArrowRecordBatch arrowBatch = batchUnloader.getRecordBatch();
    batchLoader.load(arrowBatch);

    final ArrowWritableRecordBatch records = asDataVecBatch(arrowBatch,datavecSchema,schemaRoot);
    records.setUnloader(batchUnloader);
    return Pair.of(datavecSchema,records);
}
/**
 * Convert a data vec {@link Schema}
 * to an arrow {@link org.apache.arrow.vector.types.pojo.Schema}.
 * One arrow field is created per datavec column, in column order,
 * using the column's name and type.
 *
 * @param schema the input schema
 * @return the schema for arrow
 */
public static org.apache.arrow.vector.types.pojo.Schema toArrowSchema(Schema schema) {
    final int columnCount = schema.numColumns();
    final List<Field> arrowFields = new ArrayList<>(columnCount);
    for(int col = 0; col < columnCount; col++) {
        final String columnName = schema.getName(col);
        final ColumnType columnType = schema.getType(col);
        arrowFields.add(getFieldForColumn(columnName,columnType));
    }
    return new org.apache.arrow.vector.types.pojo.Schema(arrowFields);
}
/**
 * Convert an {@link org.apache.arrow.vector.types.pojo.Schema}
 * to a datavec {@link Schema}.
 * Each arrow field is turned into column meta data and added
 * to the datavec schema in field order.
 *
 * @param schema the input arrow schema
 * @return the equivalent datavec schema
 */
public static Schema toDatavecSchema(org.apache.arrow.vector.types.pojo.Schema schema) {
    final Schema.Builder builder = new Schema.Builder();
    for (Field arrowField : schema.getFields()) {
        builder.addColumn(metaDataFromField(arrowField));
    }
    return builder.build();
}
/**
 * Shortcut method for returning a nullable field
 * given an arrow type and name
 * with no sub fields.
 *
 * @param name the name of the field
 * @param arrowType the arrow type of the field
 * @return the resulting field
 */
public static Field field(String name,ArrowType arrowType) {
    final FieldType nullableType = FieldType.nullable(arrowType);
    final List<Field> noChildren = new ArrayList<>();
    return new Field(name,nullableType,noChildren);
}
/**
 * Create a field given the input {@link ColumnType}
 * and name.
 * Long and Integer map to 64 and 32 bit arrow ints, Double and Float to
 * double and single precision floating point, Boolean to bool,
 * Categorical and String to utf8, Time to a millisecond date and
 * Bytes and NDArray to binary.
 *
 * @param name the name of the field
 * @param columnType the column type to add
 * @return the field for the given name and column type
 * @throws IllegalArgumentException if the column type has no arrow mapping
 */
public static Field getFieldForColumn(String name,ColumnType columnType) {
    final ArrowType arrowType;
    switch(columnType) {
        case Long:
            arrowType = new ArrowType.Int(64,false);
            break;
        case Integer:
            arrowType = new ArrowType.Int(32,false);
            break;
        case Double:
            arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
            break;
        case Float:
            arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
            break;
        case Boolean:
            arrowType = new ArrowType.Bool();
            break;
        case Categorical:
        case String:
            arrowType = new ArrowType.Utf8();
            break;
        case Time:
            arrowType = new ArrowType.Date(DateUnit.MILLISECOND);
            break;
        case Bytes:
        case NDArray:
            arrowType = new ArrowType.Binary();
            break;
        default:
            throw new IllegalArgumentException("Column type invalid " + columnType);
    }
    return field(name,arrowType);
}
/**
 * Shortcut method for creating a double field
 * (64 bit floating point).
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field doubleField(String name) {
    final ColumnType type = ColumnType.Double;
    return getFieldForColumn(name,type);
}
/**
 * Shortcut method for creating a float field
 * (32 bit floating point).
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field floatField(String name) {
    final ColumnType type = ColumnType.Float;
    return getFieldForColumn(name,type);
}
/**
 * Shortcut method for creating an integer field
 * (32 bit integer).
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field intField(String name) {
    final ColumnType type = ColumnType.Integer;
    return getFieldForColumn(name,type);
}
/**
 * Shortcut method for creating a long field
 * (64 bit integer).
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field longField(String name) {
    final ColumnType type = ColumnType.Long;
    return getFieldForColumn(name,type);
}
/**
 * Shortcut method for creating a string field.
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field stringField(String name) {
    final ColumnType type = ColumnType.String;
    return getFieldForColumn(name,type);
}
/**
 * Shortcut method for creating a boolean field.
 *
 * @param name the name of the field
 * @return the created field
 */
public static Field booleanField(String name) {
    final ColumnType type = ColumnType.Boolean;
    return getFieldForColumn(name,type);
}
/**
 * Provide a value look up dictionary based on the
 * given set of input {@link FieldVector} s for
 * reading and writing to arrow streams.
 * Each vector is paired with the dictionary encoding of the field at the
 * same index; when the field has none, a new ordered encoding whose id is
 * the vector's index is used.
 *
 * @param vectors the vectors to use as a lookup
 * @param fields the fields whose dictionary encodings are consulted, index-aligned with the vectors
 * @return the associated {@link DictionaryProvider} for the given
 * input {@link FieldVector} list
 */
public static DictionaryProvider providerForVectors(List<FieldVector> vectors,List<Field> fields) {
    final Dictionary[] lookup = new Dictionary[vectors.size()];
    for(int idx = 0; idx < lookup.length; idx++) {
        final DictionaryEncoding declared = fields.get(idx).getDictionary();
        final DictionaryEncoding encoding = declared != null
                ? declared
                : new DictionaryEncoding(idx,true,null);
        lookup[idx] = new Dictionary(vectors.get(idx),encoding);
    }
    return new DictionaryProvider.MapDictionaryProvider(lookup);
}
/**
 * Given a buffer allocator and datavec schema,
 * convert the passed in batch of records
 * to a set of arrow columns.
 * The vectors are filled column by column: for every column each record's
 * writable at that position is written to the row matching the record's
 * position in the batch.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to convert
 * @param dataVecRecord the data vec record batch to convert
 * @return the converted list of {@link FieldVector}
 */
public static List<FieldVector> toArrowColumns(final BufferAllocator bufferAllocator, final Schema schema, List<List<Writable>> dataVecRecord) {
    final List<FieldVector> vectors = createFieldVectors(bufferAllocator,schema,dataVecRecord.size());
    final int columnCount = schema.numColumns();
    for(int col = 0; col < columnCount; col++) {
        final FieldVector target = vectors.get(col);
        int nextRow = 0;
        for(List<Writable> record : dataVecRecord) {
            final Writable cell = record.get(col);
            setValue(schema.getType(col),target,cell,nextRow++);
        }
    }
    return vectors;
}
/**
 * Convert a set of input writables to arrow columns
 * for a time series.
 * Delegates to {@code toArrowColumnsTimeSeriesHelper}.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to use
 * @param dataVecRecord the collection of input sequences to process
 * @return the created vectors
 */
public static List<FieldVector> toArrowColumnsTimeSeries(final BufferAllocator bufferAllocator,
                                                         final Schema schema,
                                                         List<List<List<Writable>>> dataVecRecord) {
    final List<FieldVector> vectors = toArrowColumnsTimeSeriesHelper(bufferAllocator,schema,dataVecRecord);
    return vectors;
}
/**
 * Convert a set of input sequences to arrow columns
 * for a time series.
 * All sequences are flattened into the same set of vectors: every value is
 * appended to the vector of its column at that column's next free row.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to use
 * @param dataVecRecord the collection of input sequences to process
 * @param <T> the element type of the input values
 * @return the created vectors
 */
public static <T> List<FieldVector> toArrowColumnsTimeSeriesHelper(final BufferAllocator bufferAllocator,
                                                                   final Schema schema,
                                                                   List<List<List<T>>> dataVecRecord) {
    // total number of values (first step's width * number of steps, per sequence) ...
    int totalRows = 0;
    for(List<List<T>> sequence : dataVecRecord) {
        totalRows += sequence.get(0).size() * sequence.size();
    }
    // ... divided by the number of columns gives the rows per vector
    totalRows /= schema.numColumns();

    final List<FieldVector> vectors = createFieldVectors(bufferAllocator,schema,totalRows);
    // next row to write, tracked per column
    final int[] nextRow = new int[vectors.size()];

    for(List<List<T>> sequence : dataVecRecord) {
        for(List<T> timeStep : sequence) {
            for(int col = 0; col < timeStep.size(); col++) {
                final FieldVector target = vectors.get(col);
                final T cell = timeStep.get(col);
                setValue(schema.getType(col),target,cell,nextRow[col]);
                nextRow[col]++;
            }
        }
    }
    return vectors;
}
/**
 * Convert a single record of input strings to arrow columns.
 * The record is wrapped in a one element list and passed to
 * {@code toArrowColumnsString}.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to use
 * @param dataVecRecord the single record of input strings to process
 * @return the created vectors
 */
public static List<FieldVector> toArrowColumnsStringSingle(final BufferAllocator bufferAllocator, final Schema schema, List<String> dataVecRecord) {
    final List<List<String>> singleRecordBatch = Collections.singletonList(dataVecRecord);
    return toArrowColumnsString(bufferAllocator,schema,singleRecordBatch);
}
/**
 * Convert a set of input strings to arrow columns
 * for a time series.
 * Delegates to {@code toArrowColumnsTimeSeriesHelper}.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to use
 * @param dataVecRecord the collection of input strings to process
 * @return the created vectors
 */
public static List<FieldVector> toArrowColumnsStringTimeSeries(final BufferAllocator bufferAllocator,
                                                               final Schema schema,
                                                               List<List<List<String>>> dataVecRecord) {
    final List<FieldVector> vectors = toArrowColumnsTimeSeriesHelper(bufferAllocator,schema,dataVecRecord);
    return vectors;
}
/**
 * Convert a set of input strings to arrow columns.
 * One vector is created per schema column and filled column by column,
 * reading the string at the column's position from every record.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema to use
 * @param dataVecRecord the collection of input strings to process
 * @return the created vectors
 */
public static List<FieldVector> toArrowColumnsString(final BufferAllocator bufferAllocator, final Schema schema, List<List<String>> dataVecRecord) {
    final int rowCount = dataVecRecord.size();
    final List<FieldVector> vectors = createFieldVectors(bufferAllocator,schema,rowCount);
    // Need to change iteration scheme
    final int columnCount = schema.numColumns();
    for(int col = 0; col < columnCount; col++) {
        final FieldVector target = vectors.get(col);
        for(int row = 0; row < rowCount; row++) {
            final String cell = dataVecRecord.get(row).get(col);
            setValue(schema.getType(col),target,cell,row);
        }
    }
    return vectors;
}
/**
 * Create one empty vector per schema column, each sized for the
 * given number of rows and named after its column.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param schema the schema describing the columns
 * @param numRows the number of rows each vector should hold
 * @return the created vectors, in column order
 * @throws IllegalArgumentException if a column type has no matching vector
 */
private static List<FieldVector> createFieldVectors(BufferAllocator bufferAllocator,Schema schema, int numRows) {
    final List<FieldVector> vectors = new ArrayList<>(schema.numColumns());
    for(int col = 0; col < schema.numColumns(); col++) {
        final String columnName = schema.getName(col);
        final FieldVector vector;
        switch (schema.getType(col)) {
            case Integer:
                vector = intVectorOf(bufferAllocator,columnName,numRows);
                break;
            case Long:
                vector = longVectorOf(bufferAllocator,columnName,numRows);
                break;
            case Double:
                vector = doubleVectorOf(bufferAllocator,columnName,numRows);
                break;
            case Float:
                vector = floatVectorOf(bufferAllocator,columnName,numRows);
                break;
            case Boolean:
                vector = booleanVectorOf(bufferAllocator,columnName,numRows);
                break;
            case String:
            case Categorical:
                vector = stringVectorOf(bufferAllocator,columnName,numRows);
                break;
            case Time:
                vector = timeVectorOf(bufferAllocator,columnName,numRows);
                break;
            case NDArray:
                vector = ndarrayVectorOf(bufferAllocator,columnName,numRows);
                break;
            default:
                throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(col));
        }
        vectors.add(vector);
    }
    return vectors;
}
/**
 * Set the value of the specified column vector
 * at the specified row based on the given value.
 * The value will be converted relative to the specified column type.
 * Note that the passed in value may only be a {@link Writable}
 * or a {@link String}.
 * <p>
 * A {@link NullWritable} leaves the row untouched. Any exception raised while
 * converting or writing the value is logged as a warning and not rethrown.
 *
 * @param columnType the column type of the value
 * @param fieldVector the field vector to set
 * @param value the value to set ({@link Writable} or {@link String} types)
 * @param row the row of the item
 */
public static void setValue(ColumnType columnType,FieldVector fieldVector,Object value,int row) {
    if(value instanceof NullWritable) {
        return;
    }

    try {
        switch (columnType) {
            case Integer:
                if (fieldVector instanceof IntVector) {
                    IntVector intVector = (IntVector) fieldVector;
                    int set = TypeConversion.getInstance().convertInt(value);
                    intVector.set(row, set);
                } else if (fieldVector instanceof UInt4Vector) {
                    UInt4Vector uInt4Vector = (UInt4Vector) fieldVector;
                    int set = TypeConversion.getInstance().convertInt(value);
                    uInt4Vector.set(row, set);
                } else {
                    throw new UnsupportedOperationException("Illegal type " + fieldVector.getClass() + " for int type");
                }
                break;
            case Float:
                Float4Vector float4Vector = (Float4Vector) fieldVector;
                float set2 = TypeConversion.getInstance().convertFloat(value);
                float4Vector.set(row, set2);
                break;
            case Double:
                double set3 = TypeConversion.getInstance().convertDouble(value);
                Float8Vector float8Vector = (Float8Vector) fieldVector;
                float8Vector.set(row, set3);
                break;
            case Long:
                if (fieldVector instanceof BigIntVector) {
                    BigIntVector largeIntVector = (BigIntVector) fieldVector;
                    largeIntVector.set(row, TypeConversion.getInstance().convertLong(value));
                } else if (fieldVector instanceof UInt8Vector) {
                    UInt8Vector uInt8Vector = (UInt8Vector) fieldVector;
                    uInt8Vector.set(row, TypeConversion.getInstance().convertLong(value));
                } else {
                    throw new UnsupportedOperationException("Illegal type " + fieldVector.getClass() + " for long type");
                }
                break;
            case Categorical:
            case String:
                String stringSet = TypeConversion.getInstance().convertString(value);
                VarCharVector textVector = (VarCharVector) fieldVector;
                // arrow strings are UTF-8; do not depend on the platform default charset
                textVector.setSafe(row, stringSet.getBytes(java.nio.charset.StandardCharsets.UTF_8));
                break;
            case Time:
                //all timestamps are long based, just directly convert it to the super type
                long timeSet = TypeConversion.getInstance().convertLong(value);
                setLongInTime(fieldVector, row, timeSet);
                break;
            case NDArray:
                NDArrayWritable arr = (NDArrayWritable) value;
                VarBinaryVector nd4jArrayVector = (VarBinaryVector) fieldVector;
                //slice the databuffer to use only the needed portion of the buffer
                //for proper offsets
                ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
                nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
                // must not fall through into the Boolean case (would cast the vector to BitVector)
                break;
            case Boolean:
                BitVector bitVector = (BitVector) fieldVector;
                if(value instanceof Boolean)
                    bitVector.set(row, (boolean) value ? 1 : 0);
                else
                    bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0);
                break;
        }
    }catch(Exception e) {
        // best effort: keep going, but keep the cause in the log
        log.warn("Unable to set value at row " + row, e);
    }
}
/**
 * Write a long based time value into a time or timestamp vector.
 * The value is stored as given; no unit conversion is performed.
 * Timestamp vectors of any unit (with or without time zone) receive the
 * long directly, while {@link TimeMilliVector} and {@link TimeSecVector}
 * receive the value narrowed to an int.
 *
 * @param fieldVector the vector to write to
 * @param index the row to write
 * @param value the time value to store
 * @throws UnsupportedOperationException if the vector is not a supported time vector
 */
private static void setLongInTime(FieldVector fieldVector,int index,long value) {
    if(fieldVector instanceof TimeStampVector) {
        // common base class of all TimeStamp*Vector variants
        ((TimeStampVector) fieldVector).set(index,value);
    }
    else if(fieldVector instanceof TimeMilliVector) {
        ((TimeMilliVector) fieldVector).set(index,(int) value);
    }
    else if(fieldVector instanceof TimeSecVector) {
        ((TimeSecVector) fieldVector).set(index,(int) value);
    }
    else {
        throw new UnsupportedOperationException("Illegal type "
                + (fieldVector == null ? "null" : fieldVector.getClass()) + " for time type");
    }
}
/**
 * Create a millisecond timestamp vector from the given dates.
 * Each date is stored as the result of {@link Date#getTime()}.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the dates to store
 * @return the vector holding one timestamp per input date
 */
public static TimeStampMilliVector vectorFor(BufferAllocator allocator,String name,Date[] data) {
    final TimeStampMilliVector timestamps = new TimeStampMilliVector(name,allocator);
    timestamps.allocateNew(data.length);
    int index = 0;
    for(Date date : data) {
        timestamps.setSafe(index++,date.getTime());
    }
    timestamps.setValueCount(data.length);
    return timestamps;
}
/**
 * Create a millisecond timestamp vector allocated for the given
 * number of values, with its value count set to that length.
 * No values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the allocated vector
 */
public static TimeStampMilliVector timeVectorOf(BufferAllocator allocator,String name,int length) {
    final TimeStampMilliVector timestamps = new TimeStampMilliVector(name,allocator);
    timestamps.allocateNew(length);
    timestamps.setValueCount(length);
    return timestamps;
}
/**
 * Returns a vector representing a tensor view
 * of each ndarray.
 * Each ndarray will be a "row" represented as a tensor object
 * with in the return {@link VarBinaryVector}.
 * The arrays are serialized with {@link BinarySerde#toByteBuffer} and the
 * vector's value count is set to the number of input arrays.
 *
 * @param bufferAllocator the buffer allocator to use
 * @param name the name of the column
 * @param data the input arrays
 * @return the vector holding one serialized array per row
 */
public static VarBinaryVector vectorFor(BufferAllocator bufferAllocator,String name,INDArray[] data) {
    VarBinaryVector ret = new VarBinaryVector(name,bufferAllocator);
    ret.allocateNew();
    for(int i = 0; i < data.length; i++) {
        //slice the databuffer to use only the needed portion of the buffer
        //for proper offset
        ByteBuffer byteBuffer = BinarySerde.toByteBuffer(data[i]);
        // setSafe (as used for NDArray values in setValue) instead of the unchecked set,
        // since the default allocation is not sized for the serialized arrays
        ret.setSafe(i,byteBuffer,0,byteBuffer.capacity());
    }
    // without a value count the vector reports zero rows
    ret.setValueCount(data.length);
    return ret;
}
/**
 * Create a variable width character vector from the given strings.
 * Each string is stored as its UTF-8 encoded bytes.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the strings to store
 * @return the vector holding one entry per input string
 */
public static VarCharVector vectorFor(BufferAllocator allocator,String name,String[] data) {
    VarCharVector stringVector = new VarCharVector(name,allocator);
    stringVector.allocateNew();
    for(int i = 0; i < data.length; i++) {
        // arrow strings are UTF-8; do not depend on the platform default charset
        stringVector.setSafe(i,data[i].getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }
    stringVector.setValueCount(data.length);
    return stringVector;
}
/**
 * Create an ndarray vector that stores structs
 * of {@link INDArray}
 * based on the {@link org.apache.arrow.flatbuf.Tensor}
 * format.
 * The vector is allocated with {@code allocateNewSafe()} and its value
 * count is set to the given length; no values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the number of vectors to store
 * @return the allocated binary vector
 */
public static VarBinaryVector ndarrayVectorOf(BufferAllocator allocator,String name,int length) {
    final VarBinaryVector arrays = new VarBinaryVector(name,allocator);
    arrays.allocateNewSafe();
    arrays.setValueCount(length);
    return arrays;
}
/**
 * Create a variable width character vector with the default allocation
 * and its value count set to the given length.
 * No values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the allocated vector
 */
public static VarCharVector stringVectorOf(BufferAllocator allocator,String name,int length) {
    final VarCharVector strings = new VarCharVector(name,allocator);
    strings.allocateNew();
    strings.setValueCount(length);
    return strings;
}
/**
 * Create a 4 byte floating point vector holding a copy of the given values.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the values to store
 * @return the vector holding one entry per input value
 */
public static Float4Vector vectorFor(BufferAllocator allocator,String name,float[] data) {
    final Float4Vector floats = new Float4Vector(name,allocator);
    floats.allocateNew(data.length);
    int index = 0;
    for(float value : data) {
        floats.setSafe(index++,value);
    }
    floats.setValueCount(data.length);
    return floats;
}
/**
 * Create a 4 byte floating point vector allocated for the given number
 * of values, with its value count set to that length.
 * No values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the allocated vector
 */
public static Float4Vector floatVectorOf(BufferAllocator allocator,String name,int length) {
    final Float4Vector floats = new Float4Vector(name,allocator);
    floats.allocateNew(length);
    floats.setValueCount(length);
    return floats;
}
/**
 * Create an 8 byte floating point vector holding a copy of the given values.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the values to store
 * @return the vector holding one entry per input value
 */
public static Float8Vector vectorFor(BufferAllocator allocator,String name,double[] data) {
    final Float8Vector doubles = new Float8Vector(name,allocator);
    doubles.allocateNew(data.length);
    int index = 0;
    for(double value : data) {
        doubles.setSafe(index++,value);
    }
    doubles.setValueCount(data.length);
    return doubles;
}
/**
 * Create an empty double column with room for {@code length} rows.
 * The value count is set to {@code length}; no values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the newly allocated {@link Float8Vector}
 */
public static Float8Vector doubleVectorOf(BufferAllocator allocator,String name,int length) {
    Float8Vector float8Vector = new Float8Vector(name,allocator);
    // size the buffers from the requested length, matching the other
    // fixed width *VectorOf factories, instead of the default capacity
    float8Vector.allocateNew(length);
    float8Vector.setValueCount(length);
    return float8Vector;
}
/**
 * Create a boolean column holding a copy of the given values.
 * {@code true} is stored as 1 and {@code false} as 0.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the values to copy, one per row
 * @return a {@link BitVector} with {@code data.length} rows
 */
public static BitVector vectorFor(BufferAllocator allocator,String name,boolean[] data) {
    final int rows = data.length;
    final BitVector column = new BitVector(name, allocator);
    column.allocateNew(rows);
    for (int row = 0; row < rows; row++) {
        final int bit;
        if (data[row]) {
            bit = 1;
        } else {
            bit = 0;
        }
        column.setSafe(row, bit);
    }
    column.setValueCount(rows);
    return column;
}
/**
 * Create an empty boolean column with room for {@code length} rows.
 * The value count is set to {@code length}; no values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the newly allocated {@link BitVector}
 */
public static BitVector booleanVectorOf(BufferAllocator allocator,String name,int length) {
    final BitVector column = new BitVector(name, allocator);
    column.allocateNew(length);
    column.setValueCount(length);
    return column;
}
/**
 * Create a nullable, signed 32 bit integer column holding a copy
 * of the given values.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the values to copy, one per row
 * @return an {@link IntVector} with {@code data.length} rows
 */
public static IntVector vectorFor(BufferAllocator allocator,String name,int[] data) {
    final int rows = data.length;
    final FieldType signedInt32 = FieldType.nullable(new ArrowType.Int(32, true));
    final IntVector column = new IntVector(name, signedInt32, allocator);
    column.allocateNew(rows);
    int row = 0;
    while (row < rows) {
        column.setSafe(row, data[row]);
        row++;
    }
    column.setValueCount(rows);
    return column;
}
/**
 * Create an empty nullable, signed 32 bit integer column with room
 * for {@code length} rows. The value count is set to {@code length};
 * no values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the length of the vector
 * @return the newly allocated {@link IntVector}
 */
public static IntVector intVectorOf(BufferAllocator allocator,String name,int length) {
    final FieldType signedInt32 = FieldType.nullable(new ArrowType.Int(32, true));
    final IntVector column = new IntVector(name, signedInt32, allocator);
    column.allocateNew(length);
    column.setValueCount(length);
    return column;
}
/**
 * Create a nullable, signed 64 bit integer column holding a copy
 * of the given values.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param data the values to copy, one per row
 * @return a {@link BigIntVector} with {@code data.length} rows
 */
public static BigIntVector vectorFor(BufferAllocator allocator,String name,long[] data) {
    final int rows = data.length;
    final FieldType signedInt64 = FieldType.nullable(new ArrowType.Int(64, true));
    final BigIntVector column = new BigIntVector(name, signedInt64, allocator);
    column.allocateNew(rows);
    int row = 0;
    while (row < rows) {
        column.setSafe(row, data[row]);
        row++;
    }
    column.setValueCount(rows);
    return column;
}
/**
 * Create an empty nullable, signed 64 bit integer column with room
 * for {@code length} rows. The value count is set to {@code length};
 * no values are written.
 *
 * @param allocator the allocator to use
 * @param name the name of the vector
 * @param length the number of rows in the column vector
 * @return the newly allocated {@link BigIntVector}
 */
public static BigIntVector longVectorOf(BufferAllocator allocator,String name,int length) {
    final FieldType signedInt64 = FieldType.nullable(new ArrowType.Int(64, true));
    final BigIntVector column = new BigIntVector(name, signedInt64, allocator);
    column.allocateNew(length);
    column.setValueCount(length);
    return column;
}
/**
 * Translate an arrow {@link Field} into the {@link ColumnMetaData}
 * with the same name, chosen from the field's {@link ArrowType}:
 * Int (32 bit gives integer, any other width gives long), Bool,
 * FloatingPoint (DOUBLE precision gives double, anything else
 * gives float), Binary, Utf8 and Date (mapped to time).
 *
 * @param field the arrow field to convert
 * @return the column meta data for the field
 * @throws IllegalStateException if the arrow type is none of the above
 */
private static ColumnMetaData metaDataFromField(Field field) {
    final ArrowType type = field.getFieldType().getType();
    if (type instanceof ArrowType.Int) {
        final ArrowType.Int integral = (ArrowType.Int) type;
        // only a width of exactly 32 bits maps to an integer column
        if (integral.getBitWidth() == 32) {
            return new IntegerMetaData(field.getName());
        }
        return new LongMetaData(field.getName());
    }
    if (type instanceof ArrowType.Bool) {
        return new BooleanMetaData(field.getName());
    }
    if (type instanceof ArrowType.FloatingPoint) {
        final ArrowType.FloatingPoint floating = (ArrowType.FloatingPoint) type;
        if (floating.getPrecision() == FloatingPointPrecision.DOUBLE) {
            return new DoubleMetaData(field.getName());
        }
        return new FloatMetaData(field.getName());
    }
    if (type instanceof ArrowType.Binary) {
        return new BinaryMetaData(field.getName());
    }
    if (type instanceof ArrowType.Utf8) {
        return new StringMetaData(field.getName());
    }
    if (type instanceof ArrowType.Date) {
        return new TimeMetaData(field.getName());
    }
    throw new IllegalStateException("Illegal type " + field.getFieldType().getType());
}
/**
 * Based on an input {@link ColumnType}
 * get an entry from a {@link FieldVector}
 * and wrap it in the matching {@link Writable}.
 *
 * @param item the row of the item to get from the column vector;
 *             must satisfy {@code 0 <= item < from.getValueCount()}
 * @param from the column vector to read from
 * @param columnType the column type, which decides how the vector is read
 * @return the resulting writable
 * @throws IllegalArgumentException if {@code item} is out of range or
 *         the column type is not handled
 */
public static Writable fromEntry(int item,FieldVector from,ColumnType columnType) {
    final int valueCount = from.getValueCount();
    if(item < 0) {
        throw new IllegalArgumentException("Index specified must not be negative, got " + item);
    }
    // valid rows are 0 .. valueCount - 1, so item == valueCount is already out of range
    if(item >= valueCount) {
        throw new IllegalArgumentException("Index specified greater than the number of items in the vector with length " + valueCount);
    }
    switch(columnType) {
        case Integer:
            return new IntWritable(getIntFromFieldVector(item,from));
        case Long:
            return new LongWritable(getLongFromFieldVector(item,from));
        case Float:
            return new FloatWritable(getFloatFromFieldVector(item,from));
        case Double:
            return new DoubleWritable(getDoubleFromFieldVector(item,from));
        case Boolean:
            BitVector bitVector = (BitVector) from;
            return new BooleanWritable(bitVector.get(item) > 0);
        case Categorical:
        case String:
            // categorical and string columns are both read from a VarCharVector as text
            VarCharVector varCharVector = (VarCharVector) from;
            return new Text(varCharVector.get(item));
        case Time:
            //TODO: need to look at closer
            return new LongWritable(getLongFromFieldVector(item,from));
        case NDArray:
            // copy the stored bytes into a direct buffer before handing them to BinarySerde
            VarBinaryVector valueVector = (VarBinaryVector) from;
            byte[] bytes = valueVector.get(item);
            ByteBuffer direct = ByteBuffer.allocateDirect(bytes.length);
            direct.put(bytes);
            INDArray fromTensor = BinarySerde.toArray(direct);
            return new NDArrayWritable(fromTensor);
        default:
            throw new IllegalArgumentException("Illegal type " + from.getClass().getName());
    }
}
/**
 * Read one row of an integer column as an {@code int}.
 * {@link UInt4Vector} and {@link IntVector} are supported.
 *
 * @param row the row to read
 * @param fieldVector the column vector to read from
 * @return the value stored at {@code row}
 * @throws IllegalArgumentException for any other vector type
 */
private static int getIntFromFieldVector(int row,FieldVector fieldVector) {
    if (fieldVector instanceof UInt4Vector) {
        return ((UInt4Vector) fieldVector).get(row);
    }
    if (fieldVector instanceof IntVector) {
        return ((IntVector) fieldVector).get(row);
    }
    final String vectorClass = fieldVector.getClass().getName();
    throw new IllegalArgumentException("Illegal vector type for int " + vectorClass);
}
/**
 * Read one row of an integer, date, time or timestamp column as a {@code long}.
 * Narrower {@code int} values (for example from {@link IntVector}) are widened.
 *
 * @param row the row to read
 * @param fieldVector the column vector to read from
 * @return the value stored at {@code row}
 * @throws UnsupportedOperationException if the vector type is not handled
 */
private static long getLongFromFieldVector(int row,FieldVector fieldVector) {
    if(fieldVector instanceof UInt8Vector) {
        UInt8Vector uInt8Vector = (UInt8Vector) fieldVector;
        return uInt8Vector.get(row);
    }
    else if(fieldVector instanceof IntVector) {
        // an IntVector is never a BigIntVector, so casting to BigIntVector here
        // would always fail; read the int and widen it instead
        IntVector intVector = (IntVector) fieldVector;
        return intVector.get(row);
    }
    else if(fieldVector instanceof TimeStampMilliVector) {
        TimeStampMilliVector timeStampMilliVector = (TimeStampMilliVector) fieldVector;
        return timeStampMilliVector.get(row);
    }
    else if(fieldVector instanceof BigIntVector) {
        BigIntVector bigIntVector = (BigIntVector) fieldVector;
        return bigIntVector.get(row);
    }
    else if (fieldVector instanceof DateMilliVector) {
        DateMilliVector dateMilliVector = (DateMilliVector) fieldVector;
        return dateMilliVector.get(row);
    }
    else if(fieldVector instanceof TimeMilliVector) {
        TimeMilliVector timeMilliVector = (TimeMilliVector) fieldVector;
        return timeMilliVector.get(row);
    }
    else if(fieldVector instanceof TimeStampMicroVector) {
        TimeStampMicroVector timeStampMicroVector = (TimeStampMicroVector) fieldVector;
        return timeStampMicroVector.get(row);
    }
    else if(fieldVector instanceof TimeSecVector) {
        TimeSecVector timeSecVector = (TimeSecVector) fieldVector;
        return timeSecVector.get(row);
    }
    else if(fieldVector instanceof TimeStampMilliTZVector) {
        TimeStampMilliTZVector timeStampMilliTZVector = (TimeStampMilliTZVector) fieldVector;
        return timeStampMilliTZVector.get(row);
    }
    else if(fieldVector instanceof TimeStampNanoTZVector) {
        TimeStampNanoTZVector timeStampNanoTZVector = (TimeStampNanoTZVector) fieldVector;
        return timeStampNanoTZVector.get(row);
    }
    else if(fieldVector instanceof TimeStampMicroTZVector) {
        TimeStampMicroTZVector timeStampMicroTZVector = (TimeStampMicroTZVector) fieldVector;
        return timeStampMicroTZVector.get(row);
    }
    else {
        // same exception type as before, now naming the offending vector class
        throw new UnsupportedOperationException("Illegal vector type for long " + fieldVector.getClass().getName());
    }
}
/**
 * Read one row of a {@link Float8Vector} as a {@code double}.
 *
 * @param row the row to read
 * @param fieldVector the column vector to read from
 * @return the value stored at {@code row}
 * @throws IllegalArgumentException if the vector is not a {@link Float8Vector}
 */
private static double getDoubleFromFieldVector(int row,FieldVector fieldVector) {
    if(fieldVector instanceof Float8Vector) {
        Float8Vector float8Vector = (Float8Vector) fieldVector;
        return float8Vector.get(row);
    }
    // the message previously said "int", which misreported the type being read
    throw new IllegalArgumentException("Illegal vector type for double " + fieldVector.getClass().getName());
}
/**
 * Read one row of a {@link Float4Vector} as a {@code float}.
 *
 * @param row the row to read
 * @param fieldVector the column vector to read from
 * @return the value stored at {@code row}
 * @throws IllegalArgumentException if the vector is not a {@link Float4Vector}
 */
private static float getFloatFromFieldVector(int row,FieldVector fieldVector) {
    if(fieldVector instanceof Float4Vector) {
        Float4Vector float4Vector = (Float4Vector) fieldVector;
        return float4Vector.get(row);
    }
    // the message previously said "int", which misreported the type being read
    throw new IllegalArgumentException("Illegal vector type for float " + fieldVector.getClass().getName());
}
/**
 * Wrap the vectors of a {@link VectorSchemaRoot} as an
 * {@link ArrowWritableRecordBatch}. Vectors are looked up by column
 * name, in schema column order, and the given record batch is
 * attached to the result.
 *
 * @param arrowRecordBatch the record batch to attach to the result
 * @param schema the schema naming the columns to collect
 * @param vectorLoader the root to fetch each column's vector from
 * @return the batch wrapping the collected vectors
 */
private static ArrowWritableRecordBatch asDataVecBatch(ArrowRecordBatch arrowRecordBatch, Schema schema, VectorSchemaRoot vectorLoader) {
    final List<FieldVector> columns = new ArrayList<>();
    // collect one vector per schema column, preserving schema order
    int column = 0;
    while (column < schema.numColumns()) {
        columns.add(vectorLoader.getVector(schema.getName(column)));
        column++;
    }
    final ArrowWritableRecordBatch batch = new ArrowWritableRecordBatch(columns, schema);
    batch.setArrowRecordBatch(arrowRecordBatch);
    return batch;
}
}