Operating on Dask Dataframes with SQL¶
Dask-SQL is an open source project and Python package that leverages Apache Calcite to provide a SQL frontend for Dask dataframe operations, allowing SQL users to take advantage of Dask's distributed capabilities without an extensive knowledge of the dataframe API.
[1]:
! pip install dask-sql
Collecting dask-sql
  Downloading dask_sql-2022.6.0-py3-none-any.whl (21.1 MB)
...
Installing collected packages: tzdata, jpype1, h11, uvicorn, starlette, pytz-deprecation-shim, tzlocal, fastapi, dask-sql
Successfully installed dask-sql-2022.6.0 fastapi-0.79.0 h11-0.13.0 jpype1-1.4.0 pytz-deprecation-shim-0.1.0.post0 starlette-0.19.1 tzdata-2022.1 tzlocal-4.2 uvicorn-0.18.2
Set up a Dask cluster¶
Setting up a Dask cluster is optional, but it can dramatically expand our options for distributed computation by giving us access to Dask workers on GPUs, remote machines, common cloud providers, and more. Additionally, connecting our cluster to a Dask Client gives us access to a dashboard, which can be used to monitor the progress of active computations and diagnose issues.

For this notebook, we will create a local cluster and connect it to a client. Once the client has been created, a link to its associated dashboard will be displayed, which can be viewed throughout the following computations.
[2]:
from dask.distributed import Client
client = Client(n_workers=2, threads_per_worker=2, memory_limit='1GB')
client
[2]:
Client
Client-2c014484-0de0-11ed-9c67-000d3a8f7959
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Cluster Info
LocalCluster dec5a19f
Dashboard: http://127.0.0.1:8787/status
Workers: 2
Total threads: 4
Total memory: 1.86 GiB
Status: running
Using processes: True

Scheduler Info
Scheduler-93f0bc5c-2d81-4dc1-b2c8-72e6c702a5b9
Comm: tcp://127.0.0.1:38331
Dashboard: http://127.0.0.1:8787/status
Workers: 2
Total threads: 4
Total memory: 1.86 GiB
Started: Just now

Worker: 0
Comm: tcp://127.0.0.1:45347
Dashboard: http://127.0.0.1:37117/status
Nanny: tcp://127.0.0.1:43225
Total threads: 2
Memory: 0.93 GiB
Local directory: /home/runner/work/dask-examples/dask-examples/dask-worker-space/worker-pd8kj694

Worker: 1
Comm: tcp://127.0.0.1:36979
Dashboard: http://127.0.0.1:45797/status
Nanny: tcp://127.0.0.1:41733
Total threads: 2
Memory: 0.93 GiB
Local directory: /home/runner/work/dask-examples/dask-examples/dask-worker-space/worker-e6d8x_if
Create a Context¶
A dask_sql.Context is the Python equivalent of a SQL database, serving as the interface for registering all the tables and functions used in SQL queries, as well as for executing the queries themselves. Typical usage is to create a single Context and use it for the duration of a Python script or notebook.
[3]:
from dask_sql import Context
c = Context()
/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask_sql/java.py:39: UserWarning: You are running in a conda environment, but the JAVA_PATH is not using it. If this is by mistake, set $JAVA_HOME to /usr/share/miniconda3/envs/dask-examples, instead of /usr/lib/jvm/temurin-11-jdk-amd64.
warnings.warn(
Load and register data¶
Once a Context has been created, there are a variety of ways to register tables in it. The simplest is the create_table method, which accepts a variety of input types from which Dask-SQL infers the table creation method. Supported input types include:

Dask / Pandas-like dataframes
string locations of local or remote datasets
Apache Hive tables served through PyHive or SQLAlchemy

The input type can also be specified explicitly by providing a format. When registering a table, it can optionally be persisted into memory by passing persist=True, which can greatly speed up repeated queries on the same table, at the cost of loading the entire table into memory. For more information, see Data Loading and Input.
[4]:
import pandas as pd
from dask.datasets import timeseries
# register and persist a dask table
ddf = timeseries()
c.create_table("dask", ddf, persist=True)
# register a pandas table (implicitly converted to a dask table)
df = pd.DataFrame({"a": [1, 2, 3]})
c.create_table("pandas", df)
# register a table from local storage; kwargs are passed on to the underlying table creation method
c.create_table(
"local",
"surveys/data/2021-user-survey-results.csv.gz",
format="csv",
parse_dates=['Timestamp'],
blocksize=None
)
Tables can also be registered through SQL, using the sql method with a CREATE TABLE WITH or CREATE TABLE AS statement.
[5]:
# replace our table from local storage
c.sql("""
CREATE OR REPLACE TABLE
"local"
WITH (
location = 'surveys/data/2021-user-survey-results.csv.gz',
format = 'csv',
parse_dates = ARRAY [ 'Timestamp' ]
)
""")
# create a new table from a SQL query
c.sql("""
CREATE TABLE filtered AS (
SELECT id, name FROM dask WHERE name = 'Zelda'
)
""")
/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/dataframe/io/csv.py:533: UserWarning: Warning gzip compression does not support breaking apart files
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
warn(
All registered tables can be listed with a SHOW TABLES statement.
[6]:
c.sql("SHOW TABLES FROM root").compute()
[6]:
| | Table |
|---|---|
| 0 | dask |
| 1 | pandas |
| 2 | local |
| 3 | filtered |
Dask-SQL currently offers experimental GPU support, powered by the RAPIDS suite of open source GPU data science libraries. Input support is currently limited to Dask / Pandas-like dataframes and data in local/remote storage, and though most queries run without issue, users should expect some bugs or undefined behavior. To register a table and mark it for use on GPUs, gpu=True can be passed to a standard create_table call, or to its equivalent CREATE TABLE WITH query (note that this requires cuDF and Dask-cuDF).
# register a dask table for use on GPUs (not possible in this binder)
c.create_table("gpu_dask", ddf, gpu=True)
# load in a table from disk using GPU-accelerated IO operations
c.sql("""
CREATE TABLE
"gpu_local"
WITH (
location = 'surveys/data/2021-user-survey-results.csv.gz',
format = 'csv',
parse_dates = ARRAY [ 'Timestamp' ],
gpu = True
)
""")
Query the data¶
When the sql method is called, Dask-SQL hands the query off to Apache Calcite to convert it into relational algebra: essentially, a list of SQL tasks that must be executed in order to get a result. The relational algebra of any query can be viewed directly using the explain method.
[7]:
print(c.explain("SELECT AVG(x) FROM dask"))
DaskProject(EXPR$0=[/(CAST(CASE(=($1, 0), null:DOUBLE, $0)):DECIMAL(19, 15), $1)]): rowcount = 10.0, cumulative cost = {122.5 rows, 111.0 cpu, 0.0 io}, id = 83
DaskAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT($2)]): rowcount = 10.0, cumulative cost = {112.5 rows, 101.0 cpu, 0.0 io}, id = 82
DaskTableScan(table=[[root, dask]]): rowcount = 100.0, cumulative cost = {100.0 rows, 101.0 cpu, 0.0 io}, id = 77
From here, this relational algebra is converted into a Dask task graph, which ultimately returns (or, in the case of CREATE TABLE statements, implicitly assigns) a Dask dataframe.
[8]:
c.sql("SELECT AVG(x) FROM dask")
[8]:
AVG("dask"."x") | |
---|---|
npartitions=1 | |
float64 | |
... |
Dask dataframes are lazy, meaning that at the time of their creation, none of their dependent tasks have been executed yet. To actually execute these tasks and get a result, we must call compute.
[9]:
c.sql("SELECT AVG(x) FROM dask").compute()
[9]:
AVG("dask"."x") | |
---|---|
0 | -0.000302 |
Looking at the dashboard, we can see that executing this query has triggered some Dask computations.

Because the return value of a query is a Dask dataframe, it is also possible to do follow-up operations on it using Dask's dataframe API. This can be useful if we want to perform some complex operations on a dataframe that aren't possible in Dask, and then follow up with some simpler operations that can easily be expressed through the dataframe API.
[10]:
# perform a multi-column sort that isn't possible in Dask
res = c.sql("""
SELECT * FROM dask ORDER BY name ASC, id DESC, x ASC
""")
# now do some follow-up groupby aggregations
res.groupby("name").agg({"x": "sum", "y": "mean"}).compute()
[10]:
| name | x | y |
|---|---|---|
| Alice | -249.383593 | 0.001241 |
| Bob | 160.839932 | 0.000056 |
| Charlie | -77.458027 | -0.001389 |
| Dan | 141.385152 | -0.001548 |
| Edith | -33.965445 | -0.000867 |
| Frank | 31.380364 | -0.000966 |
| George | 291.711276 | -0.002320 |
| Hannah | 76.193943 | -0.001283 |
| Ingrid | 69.657261 | -0.001849 |
| Jerry | -35.406853 | -0.002052 |
| Kevin | -199.853191 | 0.000221 |
| Laura | 98.363175 | -0.001911 |
| Michael | -100.410534 | 0.004294 |
| Norbert | 189.525214 | -0.000738 |
| Oliver | -251.094045 | -0.000164 |
| Patricia | -37.815014 | 0.003536 |
| Quinn | -137.963034 | -0.001342 |
| Ray | -274.337917 | 0.004108 |
| Sarah | -237.457164 | 0.001387 |
| Tim | 67.416750 | 0.001667 |
| Ursula | -188.578720 | 0.002330 |
| Victor | -60.309784 | -0.000196 |
| Wendy | 128.743367 | 0.000112 |
| Xavier | -158.350232 | -0.001734 |
| Yvonne | 43.986670 | 0.001555 |
| Zelda | -38.438229 | 0.001045 |
Custom functions and aggregations¶
When standard SQL functionality is insufficient, it is possible to register custom functions for use in queries. These functions can be classified as one of the following:

column-wise functions
row-wise functions
aggregations
Column-wise functions¶
Column-wise functions can take columns or literal values as input and return a column of identical length. They can be registered in a Context using the register_function method.
[11]:
import numpy as np
def f(x):
return x ** 2
c.register_function(f, "f", [("x", np.float64)], np.float64)
Function registration requires the following inputs:

a callable function
a name for the function to be referred to in queries
a list of tuples representing the input variables and their respective types, which can be either Pandas or NumPy types
the type of the output column
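Combining these inputs, a sketch of registering a function of two columns might look like the following (product is a made-up example, not part of Dask-SQL):

# hypothetical column-wise function of two inputs
def product(a, b):
    return a * b

# one (name, type) tuple per input, plus the output type
c.register_function(
    product,
    "product",
    [("a", np.float64), ("b", np.float64)],
    np.float64,
)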
Once a function has been registered, it can be called like any other standard SQL function.
[12]:
c.sql("SELECT F(x) FROM dask").compute()
[12]:
"F"("dask"."x") | |
---|---|
timestamp | |
2000-01-01 00:00:00 | 0.408645 |
2000-01-01 00:00:01 | 0.497901 |
2000-01-01 00:00:02 | 0.064370 |
2000-01-01 00:00:03 | 0.421497 |
2000-01-01 00:00:04 | 0.304109 |
... | ... |
2000-01-30 23:59:55 | 0.691240 |
2000-01-30 23:59:56 | 0.499867 |
2000-01-30 23:59:57 | 0.049903 |
2000-01-30 23:59:58 | 0.004089 |
2000-01-30 23:59:59 | 0.490209 |
2592000 行 × 1 列
Row-wise functions¶
In some cases, it may be easier to write a custom function that processes a dict-like row object, i.e. a row-wise function. These functions can also be registered using register_function by passing row_udf=True, and they are used in the same manner as column-wise functions.
[13]:
def g(row):
if row["x"] > row["y"]:
return row["x"] - row["y"]
return row["y"] - row["x"]
c.register_function(g, "g", [("x", np.float64), ("y", np.float64)], np.float64, row_udf=True)
c.sql("SELECT G(x, y) FROM dask").compute()
[13]:
"G"("dask"."x", "dask"."y") | |
---|---|
timestamp | |
2000-01-01 00:00:00 | 0.446911 |
2000-01-01 00:00:01 | 0.900878 |
2000-01-01 00:00:02 | 0.052787 |
2000-01-01 00:00:03 | 0.454549 |
2000-01-01 00:00:04 | 1.157125 |
... | ... |
2000-01-30 23:59:55 | 1.603634 |
2000-01-30 23:59:56 | 1.389727 |
2000-01-30 23:59:57 | 0.671131 |
2000-01-30 23:59:58 | 0.773367 |
2000-01-30 23:59:59 | 0.023842 |
2592000 行 × 1 列
Note that, unlike column-wise functions, which are called directly using the specified columns and literals as input, row-wise functions are called using apply, which can have unpredictable performance depending on the underlying dataframe library (e.g. Pandas, cuDF) and the function itself.
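For reference, a row-wise call like the one above corresponds roughly to the following direct use of the dataframe API (a sketch of the idea, not necessarily Dask-SQL's exact internals):

# an apply over each row of the columns the function needs;
# meta tells Dask the name and dtype of the resulting column
ddf[["x", "y"]].apply(g, axis=1, meta=("g", "float64"))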
Aggregations¶
Aggregations take a single column as input and return a single value; thus, they can only be used to reduce the results of a GROUP BY query. Aggregations can be registered using the register_aggregation method, which is functionally similar to register_function but takes a Dask Aggregation as input instead of a callable function.
[14]:
import dask.dataframe as dd
my_sum = dd.Aggregation("my_sum", lambda x: x.sum(), lambda x: x.sum())
c.register_aggregation(my_sum, "my_sum", [("x", np.float64)], np.float64)
c.sql("SELECT MY_SUM(x) FROM dask").compute()
[14]:
"MY_SUM"("dask"."x") | |
---|---|
0 | -781.618678 |
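A Dask Aggregation also accepts an optional finalize step, applied after the per-partition results have been combined. As a sketch, a hand-rolled mean (my_mean is a made-up name) could carry partial counts and sums through the reduction:

# chunk runs on each partition, agg combines the partial results,
# and finalize turns them into the final value
my_mean = dd.Aggregation(
    "my_mean",
    chunk=lambda s: (s.count(), s.sum()),
    agg=lambda counts, sums: (counts.sum(), sums.sum()),
    finalize=lambda counts, sums: sums / counts,
)
c.register_aggregation(my_mean, "my_mean", [("x", np.float64)], np.float64)
# usable like any aggregate, e.g. SELECT name, MY_MEAN(x) FROM dask GROUP BY name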
Machine learning in SQL¶
Dask-SQL has support for both model training and prediction, enabling machine learning workflows through a flexible combination of Python and SQL. A model can be registered in a Context either through the register_model method or a CREATE MODEL statement.
[15]:
from sklearn.ensemble import GradientBoostingClassifier

# create a scikit-learn model and train it
model = GradientBoostingClassifier()
data = c.sql("SELECT x, y, x * y > 0 AS target FROM dask LIMIT 50")
model.fit(data[["x", "y"]], data["target"])
# register this model in the context
c.register_model("python_model", model, training_columns=["x", "y"])
# create and train a model directly from SQL
c.sql("""
CREATE MODEL sql_model WITH (
model_class = 'sklearn.ensemble.GradientBoostingClassifier',
wrap_predict = True,
target_column = 'target'
) AS (
SELECT x, y, x * y > 0 AS target
FROM dask
LIMIT 50
)
""")
Registered models must follow the scikit-learn interface by implementing a predict method; a minimal hand-written example is sketched below. As with tables, all registered models can be listed with a SHOW MODELS statement.
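As an illustration of how small that contract is, even a minimal class could hypothetically be registered the same way (ThresholdModel is a made-up class; this sketch is not executed here):

# toy model implementing the required interface: it labels a row
# positive exactly when x * y > 0, the same rule used for `target` above
class ThresholdModel:
    def predict(self, X):
        return (X["x"] * X["y"]) > 0

c.register_model("threshold_model", ThresholdModel(), training_columns=["x", "y"])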
[16]:
c.sql("SHOW MODELS").compute()
[16]:
| | Model |
|---|---|
| 0 | python_model |
| 1 | sql_model |
From here, one of these models can be used to make predictions, using the PREDICT keyword as part of a SELECT query.
[17]:
c.sql("""
SELECT * FROM PREDICT (
MODEL sql_model,
SELECT x, y, x * y > 0 AS actual FROM dask
OFFSET 50
)
""").compute()
[17]:
| timestamp | x | y | actual | target |
|---|---|---|---|---|
| 2000-01-01 00:00:50 | -0.508541 | -0.018462 | True | True |
| 2000-01-01 00:00:51 | 0.652920 | -0.847008 | False | False |
| 2000-01-01 00:00:52 | -0.779734 | 0.117797 | False | False |
| 2000-01-01 00:00:53 | 0.360605 | -0.965205 | False | False |
| 2000-01-01 00:00:54 | -0.475373 | 0.652320 | False | False |
| ... | ... | ... | ... | ... |
| 2000-01-30 23:59:55 | -0.831409 | 0.772225 | False | False |
| 2000-01-30 23:59:56 | -0.707013 | 0.682714 | False | False |
| 2000-01-30 23:59:57 | -0.223391 | 0.447740 | False | False |
| 2000-01-30 23:59:58 | 0.063943 | -0.709424 | False | False |
| 2000-01-30 23:59:59 | -0.700149 | -0.723991 | True | True |

2591950 rows × 4 columns