I'm trying to get a Daft UDF to output a `StructTy...
# general
a
I'm trying to get a Daft UDF to output a `StructType`, but I'm struggling with `PanicException: ('not implemented: List casting not implemented for dtype: Struct[top_prediction: Utf8, confidence: Float32]',)`. Examples in thread.
The code below works but returns a List of strings
Copy code
@udf(return_dtype=DataType.fixed_size_list(dtype=DataType.string(), size=2))
class ClassifyImages:
    """Class-based Daft UDF that labels image tensors with a pre-trained ResNet50.

    The declared return dtype is a fixed-size list of two strings per row.
    """

    def __init__(self):
        # Loading from torch.hub is expensive, so it is done once here
        # rather than on every call.
        hub_repo = "NVIDIA/DeepLearningExamples:torchhub"
        self.model = torch.hub.load(hub_repo, "nvidia_resnet50", pretrained=True)
        self.utils = torch.hub.load(hub_repo, "nvidia_convnets_processing_utils")

        # Inference only, on CPU.
        cpu = torch.device("cpu")
        self.model.eval().to(cpu)

    def __call__(self, tensors):
        """Classify a batch of tensors and return the top entry for each row.

        Args:
            tensors: the input column; converted via ``to_pylist()`` and
                stacked into a single torch tensor.

        Returns:
            A list with one element per row: the first (best) entry that
            ``pick_n_best`` produced for that row.
        """
        # Convert the incoming column to a single batched torch tensor.
        batch = torch.tensor(np.array(tensors.to_pylist()))

        with torch.no_grad():
            logits = self.model(batch)
            probabilities = torch.nn.functional.softmax(logits, dim=1)

        best_per_row = self.utils.pick_n_best(predictions=probabilities, n=1)

        # n=1, so each row holds exactly one entry; unwrap it.
        # NOTE(review): presumably each entry is a (label, "NN.N%") pair,
        # which matches the size-2 string list dtype — confirm.
        top_entries = []
        for row in best_per_row:
            top_entries.append(row[0])
        return top_entries
I would like to return a Struct Type instead:
DataType.struct({"top_prediction": DataType.string(), "confidence": DataType.float32()})
I've tried this, but it fails with `ArrowTypeError: Expected bytes, got a 'float' object`:
Copy code
@udf(return_dtype=DataType.struct({"top_prediction": DataType.string(), "confidence": DataType.float32()}))
class ClassifyImages:
    """Class-based Daft UDF that labels image tensors with a pre-trained ResNet50.

    Each output row is a struct with two fields: ``top_prediction`` (string)
    and ``confidence`` (float32), matching the declared return dtype.
    """

    def __init__(self):
        # Perform expensive initializations - create and load the pre-trained model
        self.model = torch.hub.load(
            "NVIDIA/DeepLearningExamples:torchhub", "nvidia_resnet50", pretrained=True
        )
        self.utils = torch.hub.load(
            "NVIDIA/DeepLearningExamples:torchhub", "nvidia_convnets_processing_utils"
        )
        # Inference only, on CPU.
        self.model.eval().to(torch.device("cpu"))

    def __call__(self, tensors):
        """Classify a batch of tensors and return one struct-shaped dict per row.

        Args:
            tensors: the input column; converted via ``to_pylist()`` and
                stacked into a single torch tensor.

        Returns:
            A list of dicts, one per row, each with the keys
            ``"top_prediction"`` and ``"confidence"``.
        """
        # Get tensors into the format the model expects.
        tensors = torch.tensor(np.array(tensors.to_pylist()))

        with torch.no_grad():
            output = torch.nn.functional.softmax(self.model(tensors), dim=1)

        results = self.utils.pick_n_best(predictions=output, n=1)

        # A struct return dtype needs one dict per row, keyed by field name.
        # Returning [label, confidence] lists makes Arrow infer a list of
        # strings and fail with "Expected bytes, got a 'float' object".
        # NOTE(review): assumes each result[0] is a (label, "NN.N%") pair;
        # the confidence is kept as the parsed percentage value, not rescaled.
        return [
            {"top_prediction": pred, "confidence": float(conf.strip("%"))}
            for pred, conf in (result[0] for result in results)
        ]
full traceback:
Copy code
---------------------------------------------------------------------------
ArrowTypeError                            Traceback (most recent call last)
Cell In[66], line 3
      1 df_classified = df_pre.with_column("classify_breed", ClassifyImages(daft.col("transformed_tensor")))
----> 3 df_classified.select("dog_name", "image", "classify_breed").show()

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/api_annotations.py:26, in DataframePublicAPI.<locals>._wrap(*args, **kwargs)
     24 type_check_function(func, *args, **kwargs)
     25 timed_method = time_df_method(func)
---> 26 return timed_method(*args, **kwargs)

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/analytics.py:189, in time_df_method.<locals>.tracked_method(*args, **kwargs)
    187 start = time.time()
    188 try:
--> 189     result = method(*args, **kwargs)
    190 except Exception as e:
    191     _ANALYTICS_CLIENT.track_df_method_call(
    192         method_name=method.__name__, duration_seconds=time.time() - start, error=str(type(e).__name__)
    193     )

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/dataframe/dataframe.py:1874, in DataFrame.show(self, n)
   1861 @DataframePublicAPI
   1862 def show(self, n: int = 8) -> None:
   1863     """Executes enough of the DataFrame in order to display the first ``n`` rows
   1864 
   1865     If IPython is installed, this will use IPython's `display` utility to pretty-print in a
   (...)
   1872         n: number of rows to show. Defaults to 8.
   1873     """
-> 1874     dataframe_display = self._construct_show_display(n)
   1875     try:
   1876         from IPython.display import display

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/dataframe/dataframe.py:1831, in DataFrame._construct_show_display(self, n)
   1829 tables = []
   1830 seen = 0
-> 1831 for table in get_context().runner().run_iter_tables(builder, results_buffer_size=1):
   1832     tables.append(table)
   1833     seen += len(table)

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/runners/pyrunner.py:198, in PyRunner.run_iter_tables(self, builder, results_buffer_size)
    195 def run_iter_tables(
    196     self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None
    197 ) -> Iterator[MicroPartition]:
--> 198     for result in self.run_iter(builder, results_buffer_size=results_buffer_size):
    199         yield result.partition()

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/runners/pyrunner.py:193, in PyRunner.run_iter(self, builder, results_buffer_size)
    191 with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"):
    192     results_gen = self._physical_plan_to_partitions(tasks)
--> 193     yield from results_gen

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/runners/pyrunner.py:293, in PyRunner._physical_plan_to_partitions(self, plan)
    291 del inflight_tasks_resources[done_id]
    292 done_task = inflight_tasks.pop(done_id)
--> 293 materialized_results = done_future.result()
    295 pbar.mark_task_done(done_task)
    297 logger.debug(
    298     "Task completed: %s -> <%s partitions>",
    299     done_id,
    300     len(materialized_results),
    301 )

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
    447     raise CancelledError()
    448 elif self._state == FINISHED:
--> 449     return self.__get_result()
    451 self._condition.wait(timeout)
    453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
    399 if self._exception:
    400     try:
--> 401         raise self._exception
    402     finally:
    403         # Break a reference cycle with the exception in self._exception
    404         self = None

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/concurrent/futures/thread.py:58, in _WorkItem.run(self)
     55     return
     57 try:
---> 58     result = self.fn(*self.args, **self.kwargs)
     59 except BaseException as exc:
     60     self.future.set_exception(exc)

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/runners/pyrunner.py:347, in PyRunner.build_partitions(instruction_stack, partitions, final_metadata)
    340 @staticmethod
    341 def build_partitions(
    342     instruction_stack: list[Instruction],
    343     partitions: list[MicroPartition],
    344     final_metadata: list[PartialPartitionMetadata],
    345 ) -> list[MaterializedResult[MicroPartition]]:
    346     for instruction in instruction_stack:
--> 347         partitions = instruction.run(partitions)
    348     return [
    349         PyMaterializedResult(part, PartitionMetadata.from_table(part).merge_with_partial(partial))
    350         for part, partial in zip(partitions, final_metadata)
    351     ]

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/execution/execution_step.py:510, in Project.run(self, inputs)
    509 def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
--> 510     return self._project(inputs)

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/execution/execution_step.py:514, in Project._project(self, inputs)
    512 def _project(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
    513     [input] = inputs
--> 514     return [input.eval_expression_list(self.projection)]

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/table/micropartition.py:176, in MicroPartition.eval_expression_list(self, exprs)
    174 assert all(isinstance(e, Expression) for e in exprs)
    175 pyexprs = [e._expr for e in exprs]
--> 176 return MicroPartition._from_pymicropartition(self._micropartition.eval_expression_list(pyexprs))

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/udf.py:126, in PartialUDF.__call__(self, evaluated_expressions)
    124         return Series.from_pylist(result, name=name, pyobj="force")._series
    125     else:
--> 126         return Series.from_pylist(result, name=name, pyobj="allow").cast(self.udf.return_dtype)._series
    127 elif _NUMPY_AVAILABLE and isinstance(result, np.ndarray):
    128     return Series.from_numpy(result, name=name).cast(self.udf.return_dtype)._series

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/daft/series.py:124, in Series.from_pylist(data, name, pyobj)
    121     return Series._from_pyseries(pys)
    123 try:
--> 124     arrow_array = pa.array(data)
    125     return Series.from_arrow(arrow_array, name=name)
    126 except pa.lib.ArrowInvalid:

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/pyarrow/array.pxi:355, in pyarrow.lib.array()

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/pyarrow/array.pxi:42, in pyarrow.lib._sequence_to_array()

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/pyarrow/error.pxi:154, in pyarrow.lib.pyarrow_internal_check_status()

File ~/miniforge3/envs/daft-pytorch/lib/python3.11/site-packages/pyarrow/error.pxi:91, in pyarrow.lib.check_status()

ArrowTypeError: Expected bytes, got a 'float' object
j
For `return_dtype=DataType.struct({"top_prediction": DataType.string(), "confidence": DataType.float32()})`, the returned data should be a list of dicts keyed by the struct's field names, like this:
Copy code
[
    {"top_prediction": "foo", "confidence": 0.5},
    {"top_prediction": "foo", "confidence": 0.5},
    {"top_prediction": "foo", "confidence": 0.5},
    ...
]
Otherwise, Daft will struggle to coerce your data into the specified type!
👍 1